refactor!: Add DualConeProjector#678
Conversation
|
Need to make these projectors public and documented. I think the public linalg package would be a nice fit, but maybe an aggregation subpackage would also be cool (torchjd.aggregation.dual_cone_projection). |
ValerianRey
left a comment
There was a problem hiding this comment.
Need changelog entry. This will be a breaking change especially if we move the norm_eps and reg_eps params to the projector. Otherwise it's breaking just for people specifying the solver (still breaking but to very few people).
|
|
||
| class DualConeProjector(ABC): | ||
| @abstractmethod | ||
| def project_weights(self, U: Tensor, G: PSDMatrix) -> Tensor: |
|
|
||
| def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector: | ||
| if projector is None: | ||
| return QPSolverBased("quadprog") |
There was a problem hiding this comment.
I think quadprog should be a subclass of QPSolverBased.
If we don't do that, we'll be unable to use solver-specific extra parameters.
|
|
||
| def forward(self, gramian: PSDMatrix, /) -> Tensor: | ||
| u = self.weighting(gramian) | ||
| G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) |
There was a problem hiding this comment.
I think the regularization and normalization should become part of the projector, because the requiered amount of regularization or projection may vary per solver. Norm_eps and reg_eps should thus also be given to the projector directly I think.
DualConeProjectorDualConeProjector
DualConeProjectorDualConeProjector
No description provided.